import numpy as np
import networkx as nx

import pickle
import os

import cdlib


if __name__ == "__main__":
    # Build a benchmark of two-community planted-partition (SBM) graphs.
    # The expected inter-community degree k_out is swept while the total
    # expected degree k = k_in + k_out is held fixed. Results are appended to
    # the pickle at `benchmark_path`, so re-running the script extends an
    # existing benchmark instead of overwriting it.

    benchmark_path = 'SBM_benchmark.pkl'

    n_graphs = 10  # number of repetitions of the full k_out sweep
    g = 500  # number of nodes per cluster

    k = 5  # total expected degree: k = k_in + k_out
    k_outs = np.linspace(0.02, 0.9, 20)
    k_ins = k - k_outs

    # Edge probabilities such that the expected intra-/inter-block degrees are
    # roughly k_in / k_out (exactly k_in*(g-1)/g inside a block of g nodes).
    p_ins = k_ins / g
    p_outs = k_outs / g

    # Resume from a previous run if one exists.
    # NOTE: pickle.load can execute arbitrary code; only load a file that this
    # script itself produced.
    if os.path.isfile(benchmark_path):
        with open(benchmark_path, 'rb') as fh:
            res = pickle.load(fh)
    else:
        res = {'graphs': [],
               'p_in': [],
               'p_out': [],
               'k': [],
               'n': [],
               'ground_truth': []}

    # np.tile repeats the whole 20-point sweep n_graphs times (not each value
    # n_graphs times in a row).
    p_ins = np.tile(p_ins, n_graphs)
    p_outs = np.tile(p_outs, n_graphs)

    # generate graphs
    for p_in, p_out in zip(p_ins, p_outs):
        print(p_in, p_out)

        # Seed with the running graph count so every graph, including those
        # appended on later runs, gets a distinct and reproducible seed.
        G = nx.planted_partition_graph(2, g, p_in, p_out, seed=len(res['graphs']))

        # Keep only the largest connected component, then relabel its nodes
        # 0..m-1. The "block" node attribute survives both steps and is used to
        # recover the planted ground truth.
        largest_cc = max(nx.connected_components(G), key=len)
        G = G.subgraph(largest_cc)
        G = nx.convert_node_labels_to_integers(G)
        C_1 = [n for n in G.nodes if G.nodes[n]["block"] == 0]
        C_2 = [n for n in G.nodes if G.nodes[n]["block"] == 1]
        # NOTE(review): if the LCC lies entirely within one block, C_1 or C_2
        # is empty; confirm downstream cdlib evaluation tolerates that.
        ground_truth = cdlib.NodeClustering(communities=[C_1, C_2], graph=G, method_name="reference")

        # collect metadata
        res['graphs'].append(G)
        res['p_in'].append(p_in)
        res['p_out'].append(p_out)
        res['k'].append(k)
        # Nominal size before LCC extraction; G itself may have fewer nodes.
        res['n'].append(2 * g)
        res['ground_truth'].append(ground_truth)

    # Save with a context manager so the file is flushed and closed
    # deterministically instead of relying on garbage collection.
    with open(benchmark_path, 'wb') as fh:
        pickle.dump(res, fh)